{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "efaabde6",
   "metadata": {
    "origin_pos": 1
   },
   "source": [
    "# Attention Scoring Functions\n",
    ":label:`sec_attention-scoring-functions`\n",
    "\n",
    "\n",
    "In :numref:`sec_attention-pooling`,\n",
    "we used a number of different distance-based kernels, including a Gaussian kernel to model\n",
    "interactions between queries and keys. As it turns out, distance functions are slightly more expensive to compute than dot products. As such, \n",
    "with the softmax operation to ensure nonnegative attention weights,\n",
    "much of the work has gone into *attention scoring functions* $a$ in :eqref:`eq_softmax_attention` and :numref:`fig_attention_output` that are simpler to compute. \n",
    "\n",
    "![Computing the output of attention pooling as a weighted average of values, where weights are computed with the attention scoring function $\\mathit{a}$ and the softmax operation.](../img/attention-output.svg)\n",
    ":label:`fig_attention_output`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8e33a108",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:45.433072Z",
     "iopub.status.busy": "2023-08-18T19:43:45.432523Z",
     "iopub.status.idle": "2023-08-18T19:43:48.504425Z",
     "shell.execute_reply": "2023-08-18T19:43:48.503548Z"
    },
    "origin_pos": 3,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8224956",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "## [**Dot Product Attention**]\n",
    "\n",
    "\n",
    "Let's review the attention function (without exponentiation) from the Gaussian kernel for a moment:\n",
    "\n",
    "$$\n",
    "a(\\mathbf{q}, \\mathbf{k}_i) = -\\frac{1}{2} \\|\\mathbf{q} - \\mathbf{k}_i\\|^2  = \\mathbf{q}^\\top \\mathbf{k}_i -\\frac{1}{2} \\|\\mathbf{k}_i\\|^2  -\\frac{1}{2} \\|\\mathbf{q}\\|^2.\n",
    "$$\n",
    "\n",
    "First, note that the final term depends on $\\mathbf{q}$ only. As such it is identical for all $(\\mathbf{q}, \\mathbf{k}_i)$ pairs. Normalizing the attention weights to $1$, as is done in :eqref:`eq_softmax_attention`, ensures that this term disappears entirely. Second, note that both batch and layer normalization (to be discussed later) lead to activations that have well-bounded, and often constant, norms $\\|\\mathbf{k}_i\\|$. This is the case, for instance, whenever the keys $\\mathbf{k}_i$ were generated by a layer norm. As such, we can drop it from the definition of $a$ without any major change in the outcome. \n",
    "\n",
    "Last, we need to keep the order of magnitude of the arguments in the exponential function under control. Assume that all the elements of the query $\\mathbf{q} \\in \\mathbb{R}^d$ and the key $\\mathbf{k}_i \\in \\mathbb{R}^d$ are independent and identically drawn random variables with zero mean and unit variance. The dot product between both vectors has zero mean and a variance of $d$. To ensure that the variance of the dot product still remains $1$ regardless of vector length, we use the *scaled dot product attention* scoring function. That is, we rescale the dot product by $1/\\sqrt{d}$. We thus arrive at the first commonly used attention function that is used, e.g., in Transformers :cite:`Vaswani.Shazeer.Parmar.ea.2017`:\n",
    "\n",
    "$$ a(\\mathbf{q}, \\mathbf{k}_i) = \\mathbf{q}^\\top \\mathbf{k}_i / \\sqrt{d}.$$\n",
    ":eqlabel:`eq_dot_product_attention`\n",
    "\n",
    "Note that attention weights $\\alpha$ still need normalizing. We can simplify this further via :eqref:`eq_softmax_attention` by using the softmax operation:\n",
    "\n",
    "$$\\alpha(\\mathbf{q}, \\mathbf{k}_i) = \\mathrm{softmax}(a(\\mathbf{q}, \\mathbf{k}_i)) = \\frac{\\exp(\\mathbf{q}^\\top \\mathbf{k}_i / \\sqrt{d})}{\\sum_{j=1} \\exp(\\mathbf{q}^\\top \\mathbf{k}_j / \\sqrt{d})}.$$\n",
    ":eqlabel:`eq_attn-scoring-alpha`\n",
    "\n",
    "As it turns out, all popular attention mechanisms use the softmax, hence we will limit ourselves to that in the remainder of this chapter.\n",
    "\n",
    "## Convenience Functions\n",
    "\n",
    "We need a few functions to make the attention mechanism efficient to deploy. This includes tools for dealing with strings of variable lengths (common for natural language processing) and tools for efficient evaluation on minibatches (batch matrix multiplication). \n",
    "\n",
    "\n",
    "### [**Masked Softmax Operation**]\n",
    "\n",
    "One of the most popular applications of the attention mechanism is to sequence models. Hence we need to be able to deal with sequences of different lengths. In some cases, such sequences may end up in the same minibatch, necessitating padding with dummy tokens for shorter sequences (see :numref:`sec_machine_translation` for an example). These special tokens do not carry meaning. For instance, assume that we have the following three sentences:\n",
    "\n",
    "```\n",
    "Dive  into  Deep    Learning \n",
    "Learn to    code    <blank>\n",
    "Hello world <blank> <blank>\n",
    "```\n",
    "\n",
    "Since we do not want blanks in our attention model we simply need to limit $\\sum_{i=1}^n \\alpha(\\mathbf{q}, \\mathbf{k}_i) \\mathbf{v}_i$ to $\\sum_{i=1}^l \\alpha(\\mathbf{q}, \\mathbf{k}_i) \\mathbf{v}_i$ for however long, $l \\leq n$, the actual sentence is. Since it is such a common problem, it has a name: the *masked softmax operation*. \n",
    "\n",
    "Let's implement it. Actually, the implementation cheats ever so slightly by setting the values of $\\mathbf{v}_i$, for $i > l$, to zero. Moreover, it sets the attention weights to a large negative number, such as $-10^{6}$, in order to make their contribution to gradients and values vanish in practice. This is done since linear algebra kernels and operators are heavily optimized for GPUs and it is faster to be slightly wasteful in computation rather than to have code with conditional (if then else) statements.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "080c4919",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:48.508521Z",
     "iopub.status.busy": "2023-08-18T19:43:48.507880Z",
     "iopub.status.idle": "2023-08-18T19:43:48.515032Z",
     "shell.execute_reply": "2023-08-18T19:43:48.514260Z"
    },
    "origin_pos": 8,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def masked_softmax(X, valid_lens):  #@save\n",
    "    \"\"\"Perform softmax operation by masking elements on the last axis.\"\"\"\n",
    "    # X: 3D tensor, valid_lens: 1D or 2D tensor\n",
    "    def _sequence_mask(X, valid_len, value=0):\n",
    "        maxlen = X.size(1)\n",
    "        mask = torch.arange((maxlen), dtype=torch.float32,\n",
    "                            device=X.device)[None, :] < valid_len[:, None]\n",
    "        X[~mask] = value\n",
    "        return X\n",
    "\n",
    "    if valid_lens is None:\n",
    "        return nn.functional.softmax(X, dim=-1)\n",
    "    else:\n",
    "        shape = X.shape\n",
    "        if valid_lens.dim() == 1:\n",
    "            valid_lens = torch.repeat_interleave(valid_lens, shape[1])\n",
    "        else:\n",
    "            valid_lens = valid_lens.reshape(-1)\n",
    "        # On the last axis, replace masked elements with a very large negative\n",
    "        # value, whose exponentiation outputs 0\n",
    "        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)\n",
    "        return nn.functional.softmax(X.reshape(shape), dim=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac9441c0",
   "metadata": {
    "origin_pos": 11
   },
   "source": [
    "To [**illustrate how this function works**],\n",
    "consider a minibatch of two examples of size $2 \\times 4$,\n",
    "where their valid lengths are $2$ and $3$, respectively. \n",
    "As a result of the masked softmax operation,\n",
    "values beyond the valid lengths for each pair of vectors are all masked as zero.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b0fb493b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:48.518456Z",
     "iopub.status.busy": "2023-08-18T19:43:48.517778Z",
     "iopub.status.idle": "2023-08-18T19:43:48.554108Z",
     "shell.execute_reply": "2023-08-18T19:43:48.553283Z"
    },
    "origin_pos": 13,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.4448, 0.5552, 0.0000, 0.0000],\n",
       "         [0.4032, 0.5968, 0.0000, 0.0000]],\n",
       "\n",
       "        [[0.2795, 0.2805, 0.4400, 0.0000],\n",
       "         [0.2798, 0.3092, 0.4110, 0.0000]]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "390fd5a8",
   "metadata": {
    "origin_pos": 16
   },
   "source": [
    "If we need more fine-grained control to specify the valid length for each of the two vectors of every example, we simply use a two-dimensional tensor of valid lengths. This yields:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0eff10c9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:48.557828Z",
     "iopub.status.busy": "2023-08-18T19:43:48.557262Z",
     "iopub.status.idle": "2023-08-18T19:43:48.564098Z",
     "shell.execute_reply": "2023-08-18T19:43:48.563239Z"
    },
    "origin_pos": 18,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n",
       "         [0.4109, 0.2794, 0.3097, 0.0000]],\n",
       "\n",
       "        [[0.3960, 0.6040, 0.0000, 0.0000],\n",
       "         [0.2557, 0.1833, 0.2420, 0.3190]]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a03f10da",
   "metadata": {
    "origin_pos": 21
   },
   "source": [
    "### Batch Matrix Multiplication\n",
    ":label:`subsec_batch_dot`\n",
    "\n",
    "Another commonly used operation is to multiply batches of matrices by one another. This comes in handy when we have minibatches of queries, keys, and values. More specifically, assume that \n",
    "\n",
    "$$\\mathbf{Q} = [\\mathbf{Q}_1, \\mathbf{Q}_2, \\ldots, \\mathbf{Q}_n]  \\in \\mathbb{R}^{n \\times a \\times b}, \\\\\n",
    "    \\mathbf{K} = [\\mathbf{K}_1, \\mathbf{K}_2, \\ldots, \\mathbf{K}_n]  \\in \\mathbb{R}^{n \\times b \\times c}.\n",
    "$$\n",
    "\n",
    "Then the batch matrix multiplication (BMM) computes the elementwise product\n",
    "\n",
    "$$\\textrm{BMM}(\\mathbf{Q}, \\mathbf{K}) = [\\mathbf{Q}_1 \\mathbf{K}_1, \\mathbf{Q}_2 \\mathbf{K}_2, \\ldots, \\mathbf{Q}_n \\mathbf{K}_n] \\in \\mathbb{R}^{n \\times a \\times c}.$$\n",
    ":eqlabel:`eq_batch-matrix-mul`\n",
    "\n",
    "Let's see this in action in a deep learning framework.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1d592456",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:48.567605Z",
     "iopub.status.busy": "2023-08-18T19:43:48.567037Z",
     "iopub.status.idle": "2023-08-18T19:43:48.572146Z",
     "shell.execute_reply": "2023-08-18T19:43:48.571131Z"
    },
    "origin_pos": 23,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "Q = torch.ones((2, 3, 4))\n",
    "K = torch.ones((2, 4, 6))\n",
    "d2l.check_shape(torch.bmm(Q, K), (2, 3, 6))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4aef7274",
   "metadata": {
    "origin_pos": 26
   },
   "source": [
    "## [**Scaled Dot Product Attention**]\n",
    "\n",
    "Let's return to the dot product attention introduced in :eqref:`eq_dot_product_attention`. \n",
    "In general, it requires that both the query and the key\n",
    "have the same vector length, say $d$, even though this can be addressed easily by replacing \n",
    "$\\mathbf{q}^\\top \\mathbf{k}$ with $\\mathbf{q}^\\top \\mathbf{M} \\mathbf{k}$ where $\\mathbf{M}$ is a matrix suitably chosen for translating between both spaces. For now assume that the dimensions match. \n",
    "\n",
    "In practice, we often think of minibatches for efficiency,\n",
    "such as computing attention for $n$ queries and $m$ key-value pairs,\n",
    "where queries and keys are of length $d$\n",
    "and values are of length $v$. The scaled dot product attention \n",
    "of queries $\\mathbf Q\\in\\mathbb R^{n\\times d}$,\n",
    "keys $\\mathbf K\\in\\mathbb R^{m\\times d}$,\n",
    "and values $\\mathbf V\\in\\mathbb R^{m\\times v}$\n",
    "thus can be written as \n",
    "\n",
    "$$ \\mathrm{softmax}\\left(\\frac{\\mathbf Q \\mathbf K^\\top }{\\sqrt{d}}\\right) \\mathbf V \\in \\mathbb{R}^{n\\times v}.$$\n",
    ":eqlabel:`eq_softmax_QK_V`\n",
    "\n",
    "Note that when applying this to a minibatch, we need the batch matrix multiplication introduced in :eqref:`eq_batch-matrix-mul`. In the following implementation of the scaled dot product attention,\n",
    "we use dropout for model regularization.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "33207d5f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:48.575743Z",
     "iopub.status.busy": "2023-08-18T19:43:48.575036Z",
     "iopub.status.idle": "2023-08-18T19:43:48.581055Z",
     "shell.execute_reply": "2023-08-18T19:43:48.580209Z"
    },
    "origin_pos": 28,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class DotProductAttention(nn.Module):  #@save\n",
    "    \"\"\"Scaled dot product attention.\"\"\"\n",
    "    def __init__(self, dropout):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    # Shape of queries: (batch_size, no. of queries, d)\n",
    "    # Shape of keys: (batch_size, no. of key-value pairs, d)\n",
    "    # Shape of values: (batch_size, no. of key-value pairs, value dimension)\n",
    "    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)\n",
    "    def forward(self, queries, keys, values, valid_lens=None):\n",
    "        d = queries.shape[-1]\n",
    "        # Swap the last two dimensions of keys with keys.transpose(1, 2)\n",
    "        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)\n",
    "        self.attention_weights = masked_softmax(scores, valid_lens)\n",
    "        return torch.bmm(self.dropout(self.attention_weights), values)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbb8df37",
   "metadata": {
    "origin_pos": 31
   },
   "source": [
    "To [**illustrate how the `DotProductAttention` class works**],\n",
    "we use the same keys, values, and valid lengths from the earlier toy example for additive attention. For the purpose of our example we assume that we have a minibatch size of $2$, a total of $10$ keys and values, and that the dimensionality of the values is $4$. Lastly, we assume that the valid length per observation is $2$ and $6$ respectively. Given that, we expect the output to be a $2 \\times 1 \\times 4$ tensor, i.e., one row per example of the minibatch.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2f449209",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:48.584794Z",
     "iopub.status.busy": "2023-08-18T19:43:48.584072Z",
     "iopub.status.idle": "2023-08-18T19:43:48.590854Z",
     "shell.execute_reply": "2023-08-18T19:43:48.589996Z"
    },
    "origin_pos": 33,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "queries = torch.normal(0, 1, (2, 1, 2))\n",
    "keys = torch.normal(0, 1, (2, 10, 2))\n",
    "values = torch.normal(0, 1, (2, 10, 4))\n",
    "valid_lens = torch.tensor([2, 6])\n",
    "\n",
    "attention = DotProductAttention(dropout=0.5)\n",
    "attention.eval()\n",
    "d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00e17255",
   "metadata": {
    "origin_pos": 36
   },
   "source": [
    "Let's check whether the attention weights actually vanish for anything beyond the second and sixth column respectively (because of setting the valid length to $2$ and $6$).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f40e370d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:48.594461Z",
     "iopub.status.busy": "2023-08-18T19:43:48.593898Z",
     "iopub.status.idle": "2023-08-18T19:43:48.969221Z",
     "shell.execute_reply": "2023-08-18T19:43:48.968308Z"
    },
    "origin_pos": 37,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"187.07675pt\" height=\"103.438906pt\" viewBox=\"0 0 187.07675 103.438906\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
       " <metadata>\n",
       "  <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
       "   <cc:Work>\n",
       "    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
       "    <dc:date>2023-08-18T19:43:48.906167</dc:date>\n",
       "    <dc:format>image/svg+xml</dc:format>\n",
       "    <dc:creator>\n",
       "     <cc:Agent>\n",
       "      <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n",
       "     </cc:Agent>\n",
       "    </dc:creator>\n",
       "   </cc:Work>\n",
       "  </rdf:RDF>\n",
       " </metadata>\n",
       " <defs>\n",
       "  <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M -0 103.438906 \n",
       "L 187.07675 103.438906 \n",
       "L 187.07675 0 \n",
       "L -0 0 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 34.240625 59.94 \n",
       "L 145.840625 59.94 \n",
       "L 145.840625 37.62 \n",
       "L 34.240625 37.62 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <g clip-path=\"url(#pfcdd75fbbc)\">\n",
       "    <image xlink:href=\"data:image/png;base64,\n",
       "iVBORw0KGgoAAAANSUhEUgAAAJsAAAAfCAYAAADwQL9CAAAAuUlEQVR4nO3aoQ3CUBiF0RaC6wBoHHMhmAASFsAyBw5NGKFhFwQOWgwTvJdc0nCOv/0rvlS1fd+vY1NouF1Kp98HDMXT+eZQd/v5KJ6+jruq04vTuWo/VbNfvwD/Q2zEiI0YsREjNmLERozYiBEbMWIjRmzEiI0YsREjNmLERozYiGm3TVf8P9t+vaw6vur7qj3T4stGjNiIERsxYiNGbMSIjRixESM2YsRGjNiIERsxYiNGbMSIjZgP5zgSA1PqCxQAAAAASUVORK5CYII=\" id=\"imagebab9152e0a\" transform=\"scale(1 -1) translate(0 -22.32)\" x=\"34.240625\" y=\"-37.62\" width=\"111.6\" height=\"22.32\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <defs>\n",
       "       <path id=\"m54afa474af\" d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m54afa474af\" x=\"39.820625\" y=\"59.94\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(36.639375 74.538438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
       "Q 1547 4250 1301 3770 \n",
       "Q 1056 3291 1056 2328 \n",
       "Q 1056 1369 1301 889 \n",
       "Q 1547 409 2034 409 \n",
       "Q 2525 409 2770 889 \n",
       "Q 3016 1369 3016 2328 \n",
       "Q 3016 3291 2770 3770 \n",
       "Q 2525 4250 2034 4250 \n",
       "z\n",
       "M 2034 4750 \n",
       "Q 2819 4750 3233 4129 \n",
       "Q 3647 3509 3647 2328 \n",
       "Q 3647 1150 3233 529 \n",
       "Q 2819 -91 2034 -91 \n",
       "Q 1250 -91 836 529 \n",
       "Q 422 1150 422 2328 \n",
       "Q 422 3509 836 4129 \n",
       "Q 1250 4750 2034 4750 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_2\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m54afa474af\" x=\"95.620625\" y=\"59.94\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 5 -->\n",
       "      <g transform=\"translate(92.439375 74.538438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
       "L 3169 4666 \n",
       "L 3169 4134 \n",
       "L 1269 4134 \n",
       "L 1269 2991 \n",
       "Q 1406 3038 1543 3061 \n",
       "Q 1681 3084 1819 3084 \n",
       "Q 2600 3084 3056 2656 \n",
       "Q 3513 2228 3513 1497 \n",
       "Q 3513 744 3044 326 \n",
       "Q 2575 -91 1722 -91 \n",
       "Q 1428 -91 1123 -41 \n",
       "Q 819 9 494 109 \n",
       "L 494 744 \n",
       "Q 775 591 1075 516 \n",
       "Q 1375 441 1709 441 \n",
       "Q 2250 441 2565 725 \n",
       "Q 2881 1009 2881 1497 \n",
       "Q 2881 1984 2565 2268 \n",
       "Q 2250 2553 1709 2553 \n",
       "Q 1456 2553 1204 2497 \n",
       "Q 953 2441 691 2322 \n",
       "L 691 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_3\">\n",
       "     <!-- Keys -->\n",
       "     <g transform=\"translate(78.371094 88.216563) scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-4b\" d=\"M 628 4666 \n",
       "L 1259 4666 \n",
       "L 1259 2694 \n",
       "L 3353 4666 \n",
       "L 4166 4666 \n",
       "L 1850 2491 \n",
       "L 4331 0 \n",
       "L 3500 0 \n",
       "L 1259 2247 \n",
       "L 1259 0 \n",
       "L 628 0 \n",
       "L 628 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
       "L 3597 1613 \n",
       "L 953 1613 \n",
       "Q 991 1019 1311 708 \n",
       "Q 1631 397 2203 397 \n",
       "Q 2534 397 2845 478 \n",
       "Q 3156 559 3463 722 \n",
       "L 3463 178 \n",
       "Q 3153 47 2828 -22 \n",
       "Q 2503 -91 2169 -91 \n",
       "Q 1331 -91 842 396 \n",
       "Q 353 884 353 1716 \n",
       "Q 353 2575 817 3079 \n",
       "Q 1281 3584 2069 3584 \n",
       "Q 2775 3584 3186 3129 \n",
       "Q 3597 2675 3597 1894 \n",
       "z\n",
       "M 3022 2063 \n",
       "Q 3016 2534 2758 2815 \n",
       "Q 2500 3097 2075 3097 \n",
       "Q 1594 3097 1305 2825 \n",
       "Q 1016 2553 972 2059 \n",
       "L 3022 2063 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-79\" d=\"M 2059 -325 \n",
       "Q 1816 -950 1584 -1140 \n",
       "Q 1353 -1331 966 -1331 \n",
       "L 506 -1331 \n",
       "L 506 -850 \n",
       "L 844 -850 \n",
       "Q 1081 -850 1212 -737 \n",
       "Q 1344 -625 1503 -206 \n",
       "L 1606 56 \n",
       "L 191 3500 \n",
       "L 800 3500 \n",
       "L 1894 763 \n",
       "L 2988 3500 \n",
       "L 3597 3500 \n",
       "L 2059 -325 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
       "L 2834 2853 \n",
       "Q 2591 2978 2328 3040 \n",
       "Q 2066 3103 1784 3103 \n",
       "Q 1356 3103 1142 2972 \n",
       "Q 928 2841 928 2578 \n",
       "Q 928 2378 1081 2264 \n",
       "Q 1234 2150 1697 2047 \n",
       "L 1894 2003 \n",
       "Q 2506 1872 2764 1633 \n",
       "Q 3022 1394 3022 966 \n",
       "Q 3022 478 2636 193 \n",
       "Q 2250 -91 1575 -91 \n",
       "Q 1294 -91 989 -36 \n",
       "Q 684 19 347 128 \n",
       "L 347 722 \n",
       "Q 666 556 975 473 \n",
       "Q 1284 391 1588 391 \n",
       "Q 1994 391 2212 530 \n",
       "Q 2431 669 2431 922 \n",
       "Q 2431 1156 2273 1281 \n",
       "Q 2116 1406 1581 1522 \n",
       "L 1381 1569 \n",
       "Q 847 1681 609 1914 \n",
       "Q 372 2147 372 2553 \n",
       "Q 372 3047 722 3315 \n",
       "Q 1072 3584 1716 3584 \n",
       "Q 2034 3584 2315 3537 \n",
       "Q 2597 3491 2834 3397 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-4b\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"60.576172\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-79\" x=\"122.099609\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"181.279297\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <defs>\n",
       "       <path id=\"mfffe0939f9\" d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#mfffe0939f9\" x=\"34.240625\" y=\"43.2\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(20.878125 46.999219) scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#mfffe0939f9\" x=\"34.240625\" y=\"54.36\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 1 -->\n",
       "      <g transform=\"translate(20.878125 58.159219) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
       "L 1825 531 \n",
       "L 1825 4091 \n",
       "L 703 3866 \n",
       "L 703 4441 \n",
       "L 1819 4666 \n",
       "L 2450 4666 \n",
       "L 2450 531 \n",
       "L 3481 531 \n",
       "L 3481 0 \n",
       "L 794 0 \n",
       "L 794 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- Queries -->\n",
       "     <g transform=\"translate(14.798437 68.087031) rotate(-90) scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-51\" d=\"M 2522 4238 \n",
       "Q 1834 4238 1429 3725 \n",
       "Q 1025 3213 1025 2328 \n",
       "Q 1025 1447 1429 934 \n",
       "Q 1834 422 2522 422 \n",
       "Q 3209 422 3611 934 \n",
       "Q 4013 1447 4013 2328 \n",
       "Q 4013 3213 3611 3725 \n",
       "Q 3209 4238 2522 4238 \n",
       "z\n",
       "M 3406 84 \n",
       "L 4238 -825 \n",
       "L 3475 -825 \n",
       "L 2784 -78 \n",
       "Q 2681 -84 2626 -87 \n",
       "Q 2572 -91 2522 -91 \n",
       "Q 1538 -91 948 567 \n",
       "Q 359 1225 359 2328 \n",
       "Q 359 3434 948 4092 \n",
       "Q 1538 4750 2522 4750 \n",
       "Q 3503 4750 4090 4092 \n",
       "Q 4678 3434 4678 2328 \n",
       "Q 4678 1516 4351 937 \n",
       "Q 4025 359 3406 84 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-75\" d=\"M 544 1381 \n",
       "L 544 3500 \n",
       "L 1119 3500 \n",
       "L 1119 1403 \n",
       "Q 1119 906 1312 657 \n",
       "Q 1506 409 1894 409 \n",
       "Q 2359 409 2629 706 \n",
       "Q 2900 1003 2900 1516 \n",
       "L 2900 3500 \n",
       "L 3475 3500 \n",
       "L 3475 0 \n",
       "L 2900 0 \n",
       "L 2900 538 \n",
       "Q 2691 219 2414 64 \n",
       "Q 2138 -91 1772 -91 \n",
       "Q 1169 -91 856 284 \n",
       "Q 544 659 544 1381 \n",
       "z\n",
       "M 1991 3584 \n",
       "L 1991 3584 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-72\" d=\"M 2631 2963 \n",
       "Q 2534 3019 2420 3045 \n",
       "Q 2306 3072 2169 3072 \n",
       "Q 1681 3072 1420 2755 \n",
       "Q 1159 2438 1159 1844 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2956 \n",
       "Q 1341 3275 1631 3429 \n",
       "Q 1922 3584 2338 3584 \n",
       "Q 2397 3584 2469 3576 \n",
       "Q 2541 3569 2628 3553 \n",
       "L 2631 2963 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-69\" d=\"M 603 3500 \n",
       "L 1178 3500 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 3500 \n",
       "z\n",
       "M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 4134 \n",
       "L 603 4134 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-51\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-75\" x=\"78.710938\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"142.089844\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-72\" x=\"203.613281\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"244.726562\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"272.509766\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"334.033203\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 34.240625 59.94 \n",
       "L 34.240625 37.62 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 145.840625 59.94 \n",
       "L 145.840625 37.62 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 34.240625 59.94 \n",
       "L 145.840625 59.94 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 34.240625 37.62 \n",
       "L 145.840625 37.62 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "  </g>\n",
       "  <g id=\"axes_2\">\n",
       "   <g id=\"patch_7\">\n",
       "    <path d=\"M 152.815625 90.36 \n",
       "L 156.973625 90.36 \n",
       "L 156.973625 7.2 \n",
       "L 152.815625 7.2 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <image xlink:href=\"data:image/png;base64,\n",
       "iVBORw0KGgoAAAANSUhEUgAAAAYAAAB0CAYAAACmJkOCAAAAyElEQVR4nM2WSwrEMAxDXcj9zzqbblp/egG/gEII02WEJMt2Qq+6f2XNN6yyO7dhGQQQo1qLFY9ij30AB5Rz6B7uIqNQKkiKAUwuAyRV8e6SmjTxgAcy5An+qTkD+vrQNVgw15sI5qWXi909kLwCpTCgDPgB83dfuXpVFDAD3vakR18H3KEqWJKJVCBAyTEgSukBw2WP9thsJP0zOAEoFcjoz22EUQ5kyFVhwH5BZx4MkBQ2Ee7mhAEjX2E88sxRCluyYg4MquoD0XoQUGMmzEcAAAAASUVORK5CYII=\" id=\"image1835bd6c16\" transform=\"scale(1 -1) translate(0 -83.52)\" x=\"152.64\" y=\"-6.48\" width=\"4.32\" height=\"83.52\"/>\n",
       "   <g id=\"matplotlib.axis_3\"/>\n",
       "   <g id=\"matplotlib.axis_4\">\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <defs>\n",
       "       <path id=\"mc813728452\" d=\"M 0 0 \n",
       "L 3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#mc813728452\" x=\"156.973625\" y=\"90.36\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 0.0 -->\n",
       "      <g transform=\"translate(163.973625 94.159219) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-2e\" d=\"M 684 794 \n",
       "L 1344 794 \n",
       "L 1344 0 \n",
       "L 684 0 \n",
       "L 684 794 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#mc813728452\" x=\"156.973625\" y=\"62.355792\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0.2 -->\n",
       "      <g transform=\"translate(163.973625 66.155011) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
       "L 3431 531 \n",
       "L 3431 0 \n",
       "L 469 0 \n",
       "L 469 531 \n",
       "Q 828 903 1448 1529 \n",
       "Q 2069 2156 2228 2338 \n",
       "Q 2531 2678 2651 2914 \n",
       "Q 2772 3150 2772 3378 \n",
       "Q 2772 3750 2511 3984 \n",
       "Q 2250 4219 1831 4219 \n",
       "Q 1534 4219 1204 4116 \n",
       "Q 875 4013 500 3803 \n",
       "L 500 4441 \n",
       "Q 881 4594 1212 4672 \n",
       "Q 1544 4750 1819 4750 \n",
       "Q 2544 4750 2975 4387 \n",
       "Q 3406 4025 3406 3419 \n",
       "Q 3406 3131 3298 2873 \n",
       "Q 3191 2616 2906 2266 \n",
       "Q 2828 2175 2409 1742 \n",
       "Q 1991 1309 1228 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-32\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#mc813728452\" x=\"156.973625\" y=\"34.351585\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 0.4 -->\n",
       "      <g transform=\"translate(163.973625 38.150804) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
       "L 825 1625 \n",
       "L 2419 1625 \n",
       "L 2419 4116 \n",
       "z\n",
       "M 2253 4666 \n",
       "L 3047 4666 \n",
       "L 3047 1625 \n",
       "L 3713 1625 \n",
       "L 3713 1100 \n",
       "L 3047 1100 \n",
       "L 3047 0 \n",
       "L 2419 0 \n",
       "L 2419 1100 \n",
       "L 313 1100 \n",
       "L 313 1709 \n",
       "L 2253 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-34\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_1\"/>\n",
       "   <g id=\"patch_8\">\n",
       "    <path d=\"M 152.815625 90.36 \n",
       "L 154.894625 90.36 \n",
       "L 156.973625 90.36 \n",
       "L 156.973625 7.2 \n",
       "L 154.894625 7.2 \n",
       "L 152.815625 7.2 \n",
       "L 152.815625 90.36 \n",
       "z\n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"pfcdd75fbbc\">\n",
       "   <rect x=\"34.240625\" y=\"37.62\" width=\"111.6\" height=\"22.32\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 250x250 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n",
    "                  xlabel='Keys', ylabel='Queries')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "583e7b8d",
   "metadata": {
    "origin_pos": 39
   },
   "source": [
    "## [**Additive Attention**]\n",
    ":label:`subsec_additive-attention`\n",
    "\n",
    "When queries $\\mathbf{q}$ and keys $\\mathbf{k}$ are vectors of different dimension,\n",
    "we can either use a matrix to address the mismatch via $\\mathbf{q}^\\top \\mathbf{M} \\mathbf{k}$, or we can use additive attention \n",
    "as the scoring function. Another benefit is that, as its name indicates, the attention is additive. This can lead to some minor computational savings. \n",
    "Given a query $\\mathbf{q} \\in \\mathbb{R}^q$\n",
    "and a key $\\mathbf{k} \\in \\mathbb{R}^k$,\n",
    "the *additive attention* scoring function :cite:`Bahdanau.Cho.Bengio.2014` is given by \n",
    "\n",
    "$$a(\\mathbf q, \\mathbf k) = \\mathbf w_v^\\top \\textrm{tanh}(\\mathbf W_q\\mathbf q + \\mathbf W_k \\mathbf k) \\in \\mathbb{R},$$\n",
    ":eqlabel:`eq_additive-attn`\n",
    "\n",
    "where $\\mathbf W_q\\in\\mathbb R^{h\\times q}$, $\\mathbf W_k\\in\\mathbb R^{h\\times k}$, \n",
    "and $\\mathbf w_v\\in\\mathbb R^{h}$ are the learnable parameters. This term is then fed into a softmax to ensure both nonnegativity and normalization. \n",
    "An equivalent interpretation of :eqref:`eq_additive-attn` is that the query and key are concatenated\n",
    "and fed into an MLP with a single hidden layer. \n",
    "Using $\\tanh$ as the activation function and disabling bias terms, \n",
    "we implement additive attention as follows:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3a2e6dee",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:48.973108Z",
     "iopub.status.busy": "2023-08-18T19:43:48.972388Z",
     "iopub.status.idle": "2023-08-18T19:43:48.979819Z",
     "shell.execute_reply": "2023-08-18T19:43:48.978914Z"
    },
    "origin_pos": 41,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class AdditiveAttention(nn.Module):  #@save\n",
    "    \"\"\"Additive attention.\"\"\"\n",
    "    def __init__(self, num_hiddens, dropout, **kwargs):\n",
    "        super(AdditiveAttention, self).__init__(**kwargs)\n",
    "        self.W_k = nn.LazyLinear(num_hiddens, bias=False)\n",
    "        self.W_q = nn.LazyLinear(num_hiddens, bias=False)\n",
    "        self.w_v = nn.LazyLinear(1, bias=False)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, queries, keys, values, valid_lens):\n",
    "        queries, keys = self.W_q(queries), self.W_k(keys)\n",
    "        # After dimension expansion, shape of queries: (batch_size, no. of\n",
    "        # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of\n",
    "        # key-value pairs, num_hiddens). Sum them up with broadcasting\n",
    "        features = queries.unsqueeze(2) + keys.unsqueeze(1)\n",
    "        features = torch.tanh(features)\n",
    "        # There is only one output of self.w_v, so we remove the last\n",
    "        # one-dimensional entry from the shape. Shape of scores: (batch_size,\n",
    "        # no. of queries, no. of key-value pairs)\n",
    "        scores = self.w_v(features).squeeze(-1)\n",
    "        self.attention_weights = masked_softmax(scores, valid_lens)\n",
    "        # Shape of values: (batch_size, no. of key-value pairs, value\n",
    "        # dimension)\n",
    "        return torch.bmm(self.dropout(self.attention_weights), values)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c313394",
   "metadata": {
    "origin_pos": 44
   },
   "source": [
    "Let's [**see how `AdditiveAttention` works**]. In our toy example we pick queries, keys and values of size \n",
    "$(2, 1, 20)$, $(2, 10, 2)$ and $(2, 10, 4)$, respectively. This is identical to our choice for `DotProductAttention`, except that now the queries are $20$-dimensional. Likewise, we pick $(2, 6)$ as the valid lengths for the sequences in the minibatch.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c1e66c95",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:48.983249Z",
     "iopub.status.busy": "2023-08-18T19:43:48.982715Z",
     "iopub.status.idle": "2023-08-18T19:43:48.993364Z",
     "shell.execute_reply": "2023-08-18T19:43:48.992407Z"
    },
    "origin_pos": 46,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "queries = torch.normal(0, 1, (2, 1, 20))\n",
    "\n",
    "attention = AdditiveAttention(num_hiddens=8, dropout=0.1)\n",
    "attention.eval()\n",
    "d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0f37c23",
   "metadata": {
    "origin_pos": 49
   },
   "source": [
    "When reviewing the attention function we see a behavior that is qualitatively quite similar to that of `DotProductAttention`. That is, only terms within the chosen valid length $(2, 6)$ are nonzero.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bf7a330b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:48.996815Z",
     "iopub.status.busy": "2023-08-18T19:43:48.996248Z",
     "iopub.status.idle": "2023-08-18T19:43:49.212301Z",
     "shell.execute_reply": "2023-08-18T19:43:49.211395Z"
    },
    "origin_pos": 50,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"187.07675pt\" height=\"103.438906pt\" viewBox=\"0 0 187.07675 103.438906\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
       " <metadata>\n",
       "  <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
       "   <cc:Work>\n",
       "    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
       "    <dc:date>2023-08-18T19:43:49.170875</dc:date>\n",
       "    <dc:format>image/svg+xml</dc:format>\n",
       "    <dc:creator>\n",
       "     <cc:Agent>\n",
       "      <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n",
       "     </cc:Agent>\n",
       "    </dc:creator>\n",
       "   </cc:Work>\n",
       "  </rdf:RDF>\n",
       " </metadata>\n",
       " <defs>\n",
       "  <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M -0 103.438906 \n",
       "L 187.07675 103.438906 \n",
       "L 187.07675 0 \n",
       "L -0 0 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 34.240625 59.94 \n",
       "L 145.840625 59.94 \n",
       "L 145.840625 37.62 \n",
       "L 34.240625 37.62 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <g clip-path=\"url(#pd589a9b0d6)\">\n",
       "    <image xlink:href=\"data:image/png;base64,\n",
       "iVBORw0KGgoAAAANSUhEUgAAAJsAAAAfCAYAAADwQL9CAAAAvElEQVR4nO3asQkCQRRF0RkxNDGzEFMbsA7BAuzDMmxjIyswNjAQA2ErENYeZuDJ4jn5409w2WjrZ7hMpdX4bp6WUkp5PpqndbfvOj2Nr/btdei6vTydu/Zztfj1A/gfYiNGbMSIjRixESM2YsRGjNiIERsxYiNGbMSIjRixESM2YsRGTD2WVfP/bIfNuuv49n7r2jMvvmzEiI0YsREjNmLERozYiBEbMWIjRmzEiI0YsREjNmLERozYiPkCGcQSA8oz1BMAAAAASUVORK5CYII=\" id=\"imagee577720131\" transform=\"scale(1 -1) translate(0 -22.32)\" x=\"34.240625\" y=\"-37.62\" width=\"111.6\" height=\"22.32\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <defs>\n",
       "       <path id=\"m1ee9d76aa4\" d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m1ee9d76aa4\" x=\"39.820625\" y=\"59.94\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(36.639375 74.538438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
       "Q 1547 4250 1301 3770 \n",
       "Q 1056 3291 1056 2328 \n",
       "Q 1056 1369 1301 889 \n",
       "Q 1547 409 2034 409 \n",
       "Q 2525 409 2770 889 \n",
       "Q 3016 1369 3016 2328 \n",
       "Q 3016 3291 2770 3770 \n",
       "Q 2525 4250 2034 4250 \n",
       "z\n",
       "M 2034 4750 \n",
       "Q 2819 4750 3233 4129 \n",
       "Q 3647 3509 3647 2328 \n",
       "Q 3647 1150 3233 529 \n",
       "Q 2819 -91 2034 -91 \n",
       "Q 1250 -91 836 529 \n",
       "Q 422 1150 422 2328 \n",
       "Q 422 3509 836 4129 \n",
       "Q 1250 4750 2034 4750 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_2\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m1ee9d76aa4\" x=\"95.620625\" y=\"59.94\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 5 -->\n",
       "      <g transform=\"translate(92.439375 74.538438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
       "L 3169 4666 \n",
       "L 3169 4134 \n",
       "L 1269 4134 \n",
       "L 1269 2991 \n",
       "Q 1406 3038 1543 3061 \n",
       "Q 1681 3084 1819 3084 \n",
       "Q 2600 3084 3056 2656 \n",
       "Q 3513 2228 3513 1497 \n",
       "Q 3513 744 3044 326 \n",
       "Q 2575 -91 1722 -91 \n",
       "Q 1428 -91 1123 -41 \n",
       "Q 819 9 494 109 \n",
       "L 494 744 \n",
       "Q 775 591 1075 516 \n",
       "Q 1375 441 1709 441 \n",
       "Q 2250 441 2565 725 \n",
       "Q 2881 1009 2881 1497 \n",
       "Q 2881 1984 2565 2268 \n",
       "Q 2250 2553 1709 2553 \n",
       "Q 1456 2553 1204 2497 \n",
       "Q 953 2441 691 2322 \n",
       "L 691 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_3\">\n",
       "     <!-- Keys -->\n",
       "     <g transform=\"translate(78.371094 88.216563) scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-4b\" d=\"M 628 4666 \n",
       "L 1259 4666 \n",
       "L 1259 2694 \n",
       "L 3353 4666 \n",
       "L 4166 4666 \n",
       "L 1850 2491 \n",
       "L 4331 0 \n",
       "L 3500 0 \n",
       "L 1259 2247 \n",
       "L 1259 0 \n",
       "L 628 0 \n",
       "L 628 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
       "L 3597 1613 \n",
       "L 953 1613 \n",
       "Q 991 1019 1311 708 \n",
       "Q 1631 397 2203 397 \n",
       "Q 2534 397 2845 478 \n",
       "Q 3156 559 3463 722 \n",
       "L 3463 178 \n",
       "Q 3153 47 2828 -22 \n",
       "Q 2503 -91 2169 -91 \n",
       "Q 1331 -91 842 396 \n",
       "Q 353 884 353 1716 \n",
       "Q 353 2575 817 3079 \n",
       "Q 1281 3584 2069 3584 \n",
       "Q 2775 3584 3186 3129 \n",
       "Q 3597 2675 3597 1894 \n",
       "z\n",
       "M 3022 2063 \n",
       "Q 3016 2534 2758 2815 \n",
       "Q 2500 3097 2075 3097 \n",
       "Q 1594 3097 1305 2825 \n",
       "Q 1016 2553 972 2059 \n",
       "L 3022 2063 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-79\" d=\"M 2059 -325 \n",
       "Q 1816 -950 1584 -1140 \n",
       "Q 1353 -1331 966 -1331 \n",
       "L 506 -1331 \n",
       "L 506 -850 \n",
       "L 844 -850 \n",
       "Q 1081 -850 1212 -737 \n",
       "Q 1344 -625 1503 -206 \n",
       "L 1606 56 \n",
       "L 191 3500 \n",
       "L 800 3500 \n",
       "L 1894 763 \n",
       "L 2988 3500 \n",
       "L 3597 3500 \n",
       "L 2059 -325 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
       "L 2834 2853 \n",
       "Q 2591 2978 2328 3040 \n",
       "Q 2066 3103 1784 3103 \n",
       "Q 1356 3103 1142 2972 \n",
       "Q 928 2841 928 2578 \n",
       "Q 928 2378 1081 2264 \n",
       "Q 1234 2150 1697 2047 \n",
       "L 1894 2003 \n",
       "Q 2506 1872 2764 1633 \n",
       "Q 3022 1394 3022 966 \n",
       "Q 3022 478 2636 193 \n",
       "Q 2250 -91 1575 -91 \n",
       "Q 1294 -91 989 -36 \n",
       "Q 684 19 347 128 \n",
       "L 347 722 \n",
       "Q 666 556 975 473 \n",
       "Q 1284 391 1588 391 \n",
       "Q 1994 391 2212 530 \n",
       "Q 2431 669 2431 922 \n",
       "Q 2431 1156 2273 1281 \n",
       "Q 2116 1406 1581 1522 \n",
       "L 1381 1569 \n",
       "Q 847 1681 609 1914 \n",
       "Q 372 2147 372 2553 \n",
       "Q 372 3047 722 3315 \n",
       "Q 1072 3584 1716 3584 \n",
       "Q 2034 3584 2315 3537 \n",
       "Q 2597 3491 2834 3397 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-4b\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"60.576172\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-79\" x=\"122.099609\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"181.279297\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <defs>\n",
       "       <path id=\"m66226f2ad7\" d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m66226f2ad7\" x=\"34.240625\" y=\"43.2\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(20.878125 46.999219) scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m66226f2ad7\" x=\"34.240625\" y=\"54.36\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 1 -->\n",
       "      <g transform=\"translate(20.878125 58.159219) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
       "L 1825 531 \n",
       "L 1825 4091 \n",
       "L 703 3866 \n",
       "L 703 4441 \n",
       "L 1819 4666 \n",
       "L 2450 4666 \n",
       "L 2450 531 \n",
       "L 3481 531 \n",
       "L 3481 0 \n",
       "L 794 0 \n",
       "L 794 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- Queries -->\n",
       "     <g transform=\"translate(14.798437 68.087031) rotate(-90) scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-51\" d=\"M 2522 4238 \n",
       "Q 1834 4238 1429 3725 \n",
       "Q 1025 3213 1025 2328 \n",
       "Q 1025 1447 1429 934 \n",
       "Q 1834 422 2522 422 \n",
       "Q 3209 422 3611 934 \n",
       "Q 4013 1447 4013 2328 \n",
       "Q 4013 3213 3611 3725 \n",
       "Q 3209 4238 2522 4238 \n",
       "z\n",
       "M 3406 84 \n",
       "L 4238 -825 \n",
       "L 3475 -825 \n",
       "L 2784 -78 \n",
       "Q 2681 -84 2626 -87 \n",
       "Q 2572 -91 2522 -91 \n",
       "Q 1538 -91 948 567 \n",
       "Q 359 1225 359 2328 \n",
       "Q 359 3434 948 4092 \n",
       "Q 1538 4750 2522 4750 \n",
       "Q 3503 4750 4090 4092 \n",
       "Q 4678 3434 4678 2328 \n",
       "Q 4678 1516 4351 937 \n",
       "Q 4025 359 3406 84 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-75\" d=\"M 544 1381 \n",
       "L 544 3500 \n",
       "L 1119 3500 \n",
       "L 1119 1403 \n",
       "Q 1119 906 1312 657 \n",
       "Q 1506 409 1894 409 \n",
       "Q 2359 409 2629 706 \n",
       "Q 2900 1003 2900 1516 \n",
       "L 2900 3500 \n",
       "L 3475 3500 \n",
       "L 3475 0 \n",
       "L 2900 0 \n",
       "L 2900 538 \n",
       "Q 2691 219 2414 64 \n",
       "Q 2138 -91 1772 -91 \n",
       "Q 1169 -91 856 284 \n",
       "Q 544 659 544 1381 \n",
       "z\n",
       "M 1991 3584 \n",
       "L 1991 3584 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-72\" d=\"M 2631 2963 \n",
       "Q 2534 3019 2420 3045 \n",
       "Q 2306 3072 2169 3072 \n",
       "Q 1681 3072 1420 2755 \n",
       "Q 1159 2438 1159 1844 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2956 \n",
       "Q 1341 3275 1631 3429 \n",
       "Q 1922 3584 2338 3584 \n",
       "Q 2397 3584 2469 3576 \n",
       "Q 2541 3569 2628 3553 \n",
       "L 2631 2963 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-69\" d=\"M 603 3500 \n",
       "L 1178 3500 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 3500 \n",
       "z\n",
       "M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 4134 \n",
       "L 603 4134 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-51\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-75\" x=\"78.710938\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"142.089844\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-72\" x=\"203.613281\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"244.726562\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"272.509766\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"334.033203\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 34.240625 59.94 \n",
       "L 34.240625 37.62 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 145.840625 59.94 \n",
       "L 145.840625 37.62 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 34.240625 59.94 \n",
       "L 145.840625 59.94 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 34.240625 37.62 \n",
       "L 145.840625 37.62 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "  </g>\n",
       "  <g id=\"axes_2\">\n",
       "   <g id=\"patch_7\">\n",
       "    <path d=\"M 152.815625 90.36 \n",
       "L 156.973625 90.36 \n",
       "L 156.973625 7.2 \n",
       "L 152.815625 7.2 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <image xlink:href=\"data:image/png;base64,\n",
       "iVBORw0KGgoAAAANSUhEUgAAAAYAAAB0CAYAAACmJkOCAAAAyElEQVR4nM2WSwrEMAxDXcj9zzqbblp/egG/gEII02WEJMt2Qq+6f2XNN6yyO7dhGQQQo1qLFY9ij30AB5Rz6B7uIqNQKkiKAUwuAyRV8e6SmjTxgAcy5An+qTkD+vrQNVgw15sI5qWXi909kLwCpTCgDPgB83dfuXpVFDAD3vakR18H3KEqWJKJVCBAyTEgSukBw2WP9thsJP0zOAEoFcjoz22EUQ5kyFVhwH5BZx4MkBQ2Ee7mhAEjX2E88sxRCluyYg4MquoD0XoQUGMmzEcAAAAASUVORK5CYII=\" id=\"image3678f8bbd2\" transform=\"scale(1 -1) translate(0 -83.52)\" x=\"152.64\" y=\"-6.48\" width=\"4.32\" height=\"83.52\"/>\n",
       "   <g id=\"matplotlib.axis_3\"/>\n",
       "   <g id=\"matplotlib.axis_4\">\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <defs>\n",
       "       <path id=\"mc7d785d4c5\" d=\"M 0 0 \n",
       "L 3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#mc7d785d4c5\" x=\"156.973625\" y=\"90.36\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 0.0 -->\n",
       "      <g transform=\"translate(163.973625 94.159219) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-2e\" d=\"M 684 794 \n",
       "L 1344 794 \n",
       "L 1344 0 \n",
       "L 684 0 \n",
       "L 684 794 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#mc7d785d4c5\" x=\"156.973625\" y=\"61.294551\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0.2 -->\n",
       "      <g transform=\"translate(163.973625 65.09377) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
       "L 3431 531 \n",
       "L 3431 0 \n",
       "L 469 0 \n",
       "L 469 531 \n",
       "Q 828 903 1448 1529 \n",
       "Q 2069 2156 2228 2338 \n",
       "Q 2531 2678 2651 2914 \n",
       "Q 2772 3150 2772 3378 \n",
       "Q 2772 3750 2511 3984 \n",
       "Q 2250 4219 1831 4219 \n",
       "Q 1534 4219 1204 4116 \n",
       "Q 875 4013 500 3803 \n",
       "L 500 4441 \n",
       "Q 881 4594 1212 4672 \n",
       "Q 1544 4750 1819 4750 \n",
       "Q 2544 4750 2975 4387 \n",
       "Q 3406 4025 3406 3419 \n",
       "Q 3406 3131 3298 2873 \n",
       "Q 3191 2616 2906 2266 \n",
       "Q 2828 2175 2409 1742 \n",
       "Q 1991 1309 1228 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-32\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#mc7d785d4c5\" x=\"156.973625\" y=\"32.229102\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 0.4 -->\n",
       "      <g transform=\"translate(163.973625 36.028321) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
       "L 825 1625 \n",
       "L 2419 1625 \n",
       "L 2419 4116 \n",
       "z\n",
       "M 2253 4666 \n",
       "L 3047 4666 \n",
       "L 3047 1625 \n",
       "L 3713 1625 \n",
       "L 3713 1100 \n",
       "L 3047 1100 \n",
       "L 3047 0 \n",
       "L 2419 0 \n",
       "L 2419 1100 \n",
       "L 313 1100 \n",
       "L 313 1709 \n",
       "L 2253 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-34\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_1\"/>\n",
       "   <g id=\"patch_8\">\n",
       "    <path d=\"M 152.815625 90.36 \n",
       "L 154.894625 90.36 \n",
       "L 156.973625 90.36 \n",
       "L 156.973625 7.2 \n",
       "L 154.894625 7.2 \n",
       "L 152.815625 7.2 \n",
       "L 152.815625 90.36 \n",
       "z\n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"pd589a9b0d6\">\n",
       "   <rect x=\"34.240625\" y=\"37.62\" width=\"111.6\" height=\"22.32\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 250x250 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),\n",
    "                  xlabel='Keys', ylabel='Queries')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62fe2877",
   "metadata": {
    "origin_pos": 52
   },
   "source": [
    "## Summary\n",
    "\n",
    "In this section we introduced the two key attention scoring functions: dot product and additive attention. They are effective tools for aggregating across sequences of variable length. In particular, the dot product attention is the mainstay of modern Transformer architectures. When queries and keys are vectors of different lengths,\n",
    "we can use the additive attention scoring function instead. Optimizing these layers is one of the key areas of advance in recent years. For instance, [NVIDIA's Transformer Library](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html) and Megatron :cite:`shoeybi2019megatron` crucially rely on efficient variants of the attention mechanism. We will dive into this in quite a bit more detail as we review Transformers in later sections. \n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Implement distance-based attention by modifying the `DotProductAttention` code. Note that you only need the squared norms of the keys $\\|\\mathbf{k}_i\\|^2$ for an efficient implementation. \n",
    "1. Modify the dot product attention to allow for queries and keys of different dimensionalities by employing a matrix to adjust dimensions. \n",
    "1. How does the computational cost scale with the dimensionality of the keys, queries, values, and their number? What about the memory bandwidth requirements?\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37239f57",
   "metadata": {
    "origin_pos": 54,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/1064)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}